{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from jax import numpy as jnp\n",
    "import numpy as onp\n",
    "import jax\n",
    "import tensorflow.compat.v2 as tf\n",
    "import argparse\n",
    "import time\n",
    "import tqdm\n",
    "from collections import OrderedDict\n",
    "\n",
    "from bnn_hmc.utils import data_utils\n",
    "from bnn_hmc.utils import models\n",
    "from bnn_hmc.utils import metrics\n",
    "from bnn_hmc.utils import losses\n",
    "from bnn_hmc.utils import checkpoint_utils\n",
    "from bnn_hmc.utils import cmd_args_utils\n",
    "from bnn_hmc.utils import logging_utils\n",
    "from bnn_hmc.utils import train_utils\n",
    "from bnn_hmc.utils import precision_utils\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    }
   ],
   "source": [
    "dtype = jnp.float32\n",
    "train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(\n",
    "    \"mnist\", dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train, y_train = train_set\n",
    "x_train, y_train = x_train[0], y_train[0]\n",
    "x_train, y_train = onp.asarray(x_train), onp.asarray(y_train)\n",
    "\n",
    "x_test, y_test = test_set\n",
    "x_test, y_test = x_test[0], y_test[0]\n",
    "x_test, y_test = onp.asarray(x_test), onp.asarray(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((60000, 28, 28, 1), (60000,))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "onp.savetxt(\"mnist_train_x.csv\", x_train.reshape((60000, -1)))\n",
    "onp.savetxt(\"mnist_train_y.csv\", y_train)\n",
    "\n",
    "onp.savetxt(\"mnist_test_x.csv\", x_test.reshape((10000, -1)))\n",
    "onp.savetxt(\"mnist_test_y.csv\", y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dtype = jnp.float32\n",
    "train_set, test_set, task, data_info = data_utils.make_ds_pmap_fullbatch(\n",
    "    \"cifar10\", dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50000, 32, 32, 3), (50000,))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train, y_train = train_set\n",
    "x_train, y_train = x_train[0], y_train[0]\n",
    "x_train, y_train = onp.asarray(x_train), onp.asarray(y_train)\n",
    "\n",
    "x_test, y_test = test_set\n",
    "x_test, y_test = x_test[0], y_test[0]\n",
    "x_test, y_test = onp.asarray(x_test), onp.asarray(y_test)\n",
    "\n",
    "x_train.shape, y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "onp.savetxt(\"cifar10_train_x.csv\", x_train.reshape((50000, -1)))\n",
    "onp.savetxt(\"cifar10_train_y.csv\", y_train)\n",
    "\n",
    "onp.savetxt(\"cifar10_test_x.csv\", x_test.reshape((10000, -1)))\n",
    "onp.savetxt(\"cifar10_test_y.csv\", y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
